from dataclasses import dataclass

@dataclass
class ModelArgs:
    """Bundle of default hyperparameters and settings for a model.

    Every field carries a default, so ``ModelArgs()`` is a valid
    configuration. Individual values can be overridden by keyword or,
    following declaration order, positionally.

    NOTE(review): the per-field remarks below are inferred from the field
    names only; confirm them against the code that consumes this config.
    """

    # "time" settings -- d_model is presumably the shared hidden width
    # (TODO confirm); 768 is evenly divisible by both head counts below.
    d_model: int = 768
    time_n_layers: int = 8
    time_n_heads: int = 16

    # Output size and "channel" settings.
    out_features: int = 2
    channel_n_heads: int = 16

    # Normalisation epsilon and dropout value.
    norm_eps: float = 1e-5
    dropout: float = 0.1

    mask_ratio: float = 0.5
    # Attention selector; the values named alongside this field are
    # "flash_attention_2", "eager" and "sdpa".
    attention: str = "eager"
    # Default filesystem location; presumably a pretrained-model directory
    # (TODO confirm what is loaded from it).
    local_path: str = "/data/pretrain_models/bert_cn/"